Sowing dates prioritization using Causal Random Forest and Policy Trees: Evidence from Agronomic Trials in Eastern India

Author

Maxwell Mkondiwa

1 Introduction

In this notebook, I use a causal machine learning estimator, i.e., a multi-armed causal random forest with augmented inverse propensity score weights (Athey et al. 2019), to estimate conditional average treatment effects (CATEs) for agronomic practices. These CATEs are estimated for each individual farm, thereby providing personalized estimates of the potential effectiveness of the practices. I then use a doubly robust estimator in a policy tree optimization (Athey and Wager 2021) to generate optimal recommendations in the form of agronomic practices that maximize potential yield gains.

2 Preliminaries

# Restore the saved workspace; the code below expects it to provide `CSISA_KVK`.
load("CSISA_KVK_Public_Workspace.RData")

# Keep the loaded object untouched and do all estimation work on a copy.
CSISA_KVKestim <- CSISA_KVK

# Number of observations recorded under each sowing schedule.
table(CSISA_KVKestim$SowingSchedule)

 T1  T2  T3  T4  T5 
408 858 806 881 695 
# Store the schedule as an ordered factor whose first level is T5.
CSISA_KVKestim$SowingSchedule <- factor(
  CSISA_KVKestim$SowingSchedule,
  levels = c("T5", "T4", "T3", "T2", "T1"),
  ordered = TRUE
)

# Expand selected categorical variables into one indicator column per level.
library(fastDummies)
CSISA_KVKestim <- fastDummies::dummy_cols(
  CSISA_KVKestim,
  select_columns = c(
    "VarietyClass", "SoilType", "CropEstablishment", "Year", "District"
  )
)

2.1 Graphics

# Bar graphs showing how many farmers fall under each practice

library(tidyverse)
library(ggplot2)

# Horizontal bar chart of the counts of one categorical column.
#   dat: data frame holding the column.
#   var: unquoted column name; rows where it is missing are dropped and the
#        levels are ordered from most to least frequent.
# Returns a ggplot object.
bar_chart <- function(dat, var) {
  complete_rows <- drop_na(dat, {{ var }})
  by_frequency <- mutate(
    complete_rows,
    {{ var }} := fct_infreq(factor({{ var }}))
  )

  ggplot(by_frequency) +
    geom_bar(aes(y = {{ var }}), fill = "dodgerblue4") +
    theme_minimal(base_size = 16)
}

sow_plot <- bar_chart(CSISA_KVKestim, SowingSchedule) +
  labs(y = "Sowing dates")

sow_plot

library(ggpubr)
library(tidyverse)

# Sowing dates: mean grain yield with error bars for each sowing schedule,
# drawn horizontally. Rows without a sowing schedule are excluded first.
SowingDate_Options_Errorplot <- ggerrorplot(
  drop_na(CSISA_KVKestim, SowingSchedule),
  x = "SowingSchedule",
  y = "GrainYield",
  add = "mean",
  error.plot = "errorbar",
  color = "steelblue",
  ggtheme = theme_bw()
) +
  labs(x = "Sowing date options", y = "Wheat yield (t/ha)") +
  theme_minimal(base_size = 16) +
  coord_flip()
Warning: The `fun.y` argument of `stat_summary()` is deprecated as of ggplot2 3.3.0.
i Please use the `fun` argument instead.
i The deprecated feature was likely used in the ggpubr package.
  Please report the issue at <https://github.com/kassambara/ggpubr/issues>.
Warning: The `fun.ymin` argument of `stat_summary()` is deprecated as of ggplot2 3.3.0.
i Please use the `fun.min` argument instead.
i The deprecated feature was likely used in the ggpubr package.
  Please report the issue at <https://github.com/kassambara/ggpubr/issues>.
Warning: The `fun.ymax` argument of `stat_summary()` is deprecated as of ggplot2 3.3.0.
i Please use the `fun.max` argument instead.
i The deprecated feature was likely used in the ggpubr package.
  Please report the issue at <https://github.com/kassambara/ggpubr/issues>.
# Render the sowing-date error-bar plot.
SowingDate_Options_Errorplot 

2.2 Descriptives

library(fBasics)

# Descriptive statistics for the outcome and the dummy-coded covariates.
summ_stats <- fBasics::basicStats(
  CSISA_KVKestim[, c(
    "GrainYield",
    "VarietyClass_LDV",
    "SoilType_Heavy", "SoilType_Medium", "SoilType_Low",
    "CropEstablishment_CT", "CropEstablishment_CT-line", "CropEstablishment_ZT"
  )]
)

# One row per variable, one column per statistic.
summ_stats <- as.data.frame(t(summ_stats))

# Keep the headline statistics and give the quartile columns readable names.
summ_stats <- summ_stats[c(
  "Mean", "Stdev", "Minimum", "1. Quartile", "Median", "3. Quartile", "Maximum"
)]
summ_stats <- summ_stats %>%
  rename("Lower quartile" = "1. Quartile", "Upper quartile" = "3. Quartile")

summ_stats
                              Mean    Stdev Minimum Lower quartile Median
GrainYield                4.067177 1.035497    1.39           3.23   4.06
VarietyClass_LDV          0.710526 0.453580    0.00           0.00   1.00
SoilType_Heavy            0.214090 0.410246    0.00           0.00   0.00
SoilType_Medium           0.763980 0.424693    0.00           1.00   1.00
SoilType_Low              0.021930 0.146475    0.00           0.00   0.00
CropEstablishment_CT      0.020285 0.140993    0.00           0.00   0.00
CropEstablishment_CT.line 0.004386 0.066090    0.00           0.00   0.00
CropEstablishment_ZT      0.975329 0.155142    0.00           1.00   1.00
                          Upper quartile Maximum
GrainYield                          4.84    6.74
VarietyClass_LDV                    1.00    1.00
SoilType_Heavy                      0.00    1.00
SoilType_Medium                     1.00    1.00
SoilType_Low                        0.00    1.00
CropEstablishment_CT                0.00    1.00
CropEstablishment_CT.line           0.00    1.00
CropEstablishment_ZT                1.00    1.00

3 Causal Random Forest Model

library(grf)
library(policytree)
library(tidyr)

# Covariates used both to build the estimation sample and as the forest
# feature set. Declared once so the two selections cannot drift apart.
covariates_sowing <- c(
  "VarietyClass_LDV",
  "SoilType_Heavy", "SoilType_Medium", "SoilType_Low",
  "CropEstablishment_CT", "CropEstablishment_CT-line", "CropEstablishment_ZT",
  "wc2.1_30s_elev", "nitrogen_0.5cm", "sand_0.5cm", "soc_5.15cm",
  "Latitude", "Longitude"
)

# Estimation sample: treatment, outcome and covariates, complete cases only.
CSISA_KVKestim_sow <- CSISA_KVKestim[
  , c("SowingSchedule", "GrainYield", covariates_sowing),
  drop = FALSE
]
CSISA_KVKestim_sow <- drop_na(CSISA_KVKestim_sow)

## Causal random forest -----------------

# Outcome, covariates and treatment, all taken from the same filtered rows.
Y_cf_sowing <- as.vector(CSISA_KVKestim_sow$GrainYield)
X_cf_sowing <- CSISA_KVKestim_sow[, covariates_sowing, drop = FALSE]
W_cf_sowing <- as.factor(CSISA_KVKestim_sow$SowingSchedule)

# Propensity model: probability of each sowing schedule given the covariates.
W.multi_sowing.forest <- probability_forest(
  X_cf_sowing, W_cf_sowing,
  equalize.cluster.weights = FALSE,
  seed = 2
)
# Only the predicted probabilities are used downstream, so no variance
# estimates are requested.
W.hat.multi.all_sowing <- predict(W.multi_sowing.forest)$predictions

# Outcome model: expected grain yield given the covariates.
Y.multi_sowing.forest <- regression_forest(
  X_cf_sowing, Y_cf_sowing,
  equalize.cluster.weights = FALSE,
  seed = 2
)

print(Y.multi_sowing.forest)
GRF forest object of type regression_forest 
Number of trees: 2000 
Number of training samples: 3628 
Variable importance: 
    1     2     3     4     5     6     7     8     9    10    11    12    13 
0.641 0.002 0.001 0.000 0.002 0.000 0.001 0.101 0.008 0.009 0.040 0.036 0.158 
# Variable importance of the outcome (regression) forest.
varimp.multi_sowing <- variable_importance(Y.multi_sowing.forest)

# Predicted yield from the outcome forest. Only the point predictions are
# passed on, so variance estimation is not requested.
Y.hat.multi.all_sowing <- predict(Y.multi_sowing.forest)$predictions

# Multi-arm causal forest, supplied with the pre-computed propensity (W.hat)
# and outcome (Y.hat) estimates.
multi_sowing.forest <- multi_arm_causal_forest(
  X = X_cf_sowing,
  Y = Y_cf_sowing,
  W = W_cf_sowing,
  W.hat = W.hat.multi.all_sowing,
  Y.hat = Y.hat.multi.all_sowing,
  seed = 2
)

# Variable importance of the causal forest.
varimp.multi_sowing_cf <- variable_importance(multi_sowing.forest)

# Average treatment effect of each schedule relative to the first factor level,
# using the AIPW method.
multi_sowing_ate <- average_treatment_effect(multi_sowing.forest, method = "AIPW")
Warning in get_scores.multi_arm_causal_forest(forest, subset = subset, debiasing.weights = debiasing.weights, : Estimated treatment propensities take values very close to 0 or 1 meaning some estimates may not be well identified. In particular, the minimum propensity estimates for each arm is
T5: 0 T4: 0.004 T3: 0.002 T2: 0 T1: 0
and the maximum is
T5: 0.893 T4: 0.873 T3: 0.834 T2: 0.922 T1: 0.941.
# Print the estimated contrasts together with their standard errors.
multi_sowing_ate
         estimate    std.err contrast outcome
T4 - T5 0.3666129 0.04368794  T4 - T5     Y.1
T3 - T5 0.8570305 0.04010667  T3 - T5     Y.1
T2 - T5 0.8089736 0.22419766  T2 - T5     Y.1
T1 - T5 1.6852136 0.25117241  T1 - T5     Y.1
varimp.multi_sowing_cf <- variable_importance(multi_sowing.forest)
# Covariate names, in the same order as the columns given to the forest.
vars_sowing <- c(
  "VarietyClass_LDV",
  "SoilType_Heavy", "SoilType_Medium", "SoilType_Low",
  "CropEstablishment_CT", "CropEstablishment_CT-line", "CropEstablishment_ZT",
  "wc2.1_30s_elev", "nitrogen_0.5cm", "sand_0.5cm", "soc_5.15cm",
  "Latitude", "Longitude"
)

## variable importance plot ----------------------------------------------------
# Build the plotting frame column by column so the importance stays numeric.
# cbind()-ing it with the names turned every value into text, which then had
# to be parsed back; formatC() only pads character input, so that round trip
# does not appear to have rounded anything. Values are kept at full precision.
varimpvars_sowing <- data.frame(
  Variableimportance_sowing = as.numeric(varimp.multi_sowing_cf),
  vars_sowing = vars_sowing
)

# geom_point() plots each importance at its exact value; jittering shifted the
# points randomly on every render.
varimpplotRF_sowing <- ggplot(
  varimpvars_sowing,
  aes(
    x = reorder(vars_sowing, Variableimportance_sowing),
    y = Variableimportance_sowing
  )
) +
  geom_point(color = "steelblue") +
  coord_flip() +
  labs(x = "Variables", y = "Variable importance")

# Switch the session-wide default theme; the previous one is kept so it can be
# restored with theme_set(previous_theme).
previous_theme <- theme_set(theme_bw(base_size = 16))
varimpplotRF_sowing

# Policy tree --------------------------------------
# Doubly robust scores from the causal forest.
DR.scores_sowing <- double_robust_scores(multi_sowing.forest)

# Depth-2 policy tree over the covariates.
tr_sowing <- policy_tree(X = X_cf_sowing, Gamma = DR.scores_sowing, depth = 2)
plot(tr_sowing)

# Depth-3 tree fitted with the hybrid search.
tr_sowing3 <- hybrid_policy_tree(
  X = X_cf_sowing,
  Gamma = DR.scores_sowing,
  depth = 3
)
print(tr_sowing3)
policy_tree object 
Tree depth:  3 
Actions:  1: T5 2: T4 3: T3 4: T2 5: T1 
Variable splits: 
(1) split_variable: Longitude  split_value: 84.953 
  (2) split_variable: Longitude  split_value: 83.914 
    (4) split_variable: Latitude  split_value: 25.426 
      (8) * action: 3 
      (9) * action: 5 
    (5) split_variable: nitrogen_0.5cm  split_value: 2.1 
      (10) * action: 5 
      (11) * action: 4 
  (3) split_variable: wc2.1_30s_elev  split_value: 51 
    (6) split_variable: Latitude  split_value: 25.025 
      (12) * action: 3 
      (13) * action: 5 
    (7) split_variable: Longitude  split_value: 84.964 
      (14) * action: 2 
      (15) * action: 5 
# Draw the depth-3 tree, then fit and print a depth-4 hybrid tree on the same
# covariates and scores.
plot(tr_sowing3)
tr_sowing4 <- hybrid_policy_tree(
  X = X_cf_sowing,
  Gamma = DR.scores_sowing,
  depth = 4
)
print(tr_sowing4)
policy_tree object 
Tree depth:  4 
Actions:  1: T5 2: T4 3: T3 4: T2 5: T1 
Variable splits: 
(1) split_variable: Longitude  split_value: 84.953 
  (2) split_variable: Longitude  split_value: 83.914 
    (4) split_variable: Latitude  split_value: 25.426 
      (8) split_variable: Longitude  split_value: 83.913 
        (16) * action: 3 
        (17) * action: 1 
      (9) split_variable: Latitude  split_value: 26.417 
        (18) * action: 5 
        (19) * action: 3 
    (5) split_variable: Longitude  split_value: 83.929 
      (10) split_variable: Longitude  split_value: 83.928 
        (20) * action: 4 
        (21) * action: 1 
      (11) split_variable: nitrogen_0.5cm  split_value: 2.1 
        (22) * action: 5 
        (23) * action: 4 
  (3) split_variable: wc2.1_30s_elev  split_value: 51 
    (6) split_variable: wc2.1_30s_elev  split_value: 43 
      (12) split_variable: soc_5.15cm  split_value: 14.5 
        (24) * action: 4 
        (25) * action: 5 
      (13) split_variable: Latitude  split_value: 25.025 
        (26) * action: 3 
        (27) * action: 5 
    (7) split_variable: wc2.1_30s_elev  split_value: 53 
      (14) split_variable: Longitude  split_value: 85.415 
        (28) * action: 2 
        (29) * action: 5 
      (15) split_variable: Longitude  split_value: 84.964 
        (30) * action: 2 
        (31) * action: 5 
plot(tr_sowing4)

# Start the assignment table from the estimation sample and attach the action
# that the depth-2 tree picks for each row.
tr_assignment_sowing <- CSISA_KVKestim_sow
tr_assignment_sowing[["depth2"]] <- predict(tr_sowing, newdata = X_cf_sowing)
table(tr_assignment_sowing[["depth2"]])

   2    4    5 
  38  200 3390 
# Action chosen by the depth-3 tree for each row, and how often each occurs.
tr_assignment_sowing[["depth3"]] <- predict(tr_sowing3, newdata = X_cf_sowing)
table(tr_assignment_sowing[["depth3"]])

   2    3    4    5 
  38  188  200 3202 
# Action chosen by the depth-4 tree for each row, and how often each occurs.
tr_assignment_sowing[["depth4"]] <- predict(tr_sowing4, newdata = X_cf_sowing)
table(tr_assignment_sowing[["depth4"]])

   1    2    3    4    5 
  40   47  168  304 3069 

4 Policy learning algorithm for treatment assignment

# SpatialPointsDataFrame() and CRS() come from sp. rgdal used to attach sp as a
# side effect but has been retired from CRAN, so it is attached only when it is
# still installed rather than letting a missing package abort the script.
library(sp)
if (requireNamespace("rgdal", quietly = TRUE)) {
  library(rgdal)
}

# Label for each policy-tree action id; position i is the label for action i
# (the trees report actions as 1: T5, 2: T4, 3: T3, 4: T2, 5: T1).
sowing_labels <- c("T5_16Dec", "T4_15Dec", "T3_30Nov", "T2_20Nov", "T1_10Nov")

# One vectorised lookup that fills the whole column; ids outside 1-5 become NA.
tr_assignment_sowing$depth2_cat <-
  sowing_labels[as.integer(tr_assignment_sowing$depth2)]

# Point layer in longitude/latitude (WGS84) carrying the assignment table.
tr_assignment_sowingsp <- SpatialPointsDataFrame(
  cbind(tr_assignment_sowing$Longitude, tr_assignment_sowing$Latitude),
  data = tr_assignment_sowing,
  proj4string = CRS("+proj=longlat +datum=WGS84")
)

library(mapview)
mapviewOptions(fgb = FALSE)

# Interactive map coloured by the depth-2 recommendation.
tr_assignment_sowingspmapview <- mapview(
  tr_assignment_sowingsp,
  zcol = "depth2_cat",
  layer.name = "Recommended sowing dates"
)
tr_assignment_sowingspmapview

5 Distributional analysis

library(ggridges)
library(dplyr)
library(reshape2)

# Per-observation treatment-effect predictions with variance estimates.
# `target.sample` is an argument of average_treatment_effect(), not of the
# forest's predict() method, so it was being silently ignored and is omitted.
tau.multi_sowing.forest <- predict(multi_sowing.forest, estimate.variance = TRUE)
tau.multi_sowing.forest <- as.data.frame(tau.multi_sowing.forest)

# Estimation sample side by side with the predictions.
tau.multi_sowing.forest_X <- data.frame(CSISA_KVKestim_sow, tau.multi_sowing.forest)

# Ridges -------------------
# The first four columns hold the predicted contrasts against T5 (they are the
# columns renamed just below).
tau.multi_sowing.forest_pred <- tau.multi_sowing.forest[, 1:4]

# NOTE(review): the first label is spaced differently from the other three
# ("A - B" vs "A-B"); kept as is because the names may be referenced elsewhere.
tau.multi_sowing.forest_pred <- dplyr::rename(
  tau.multi_sowing.forest_pred,
  "T4_15Dec - T5_16Dec" = "predictions.T4...T5.Y.1",
  "T3_30Nov-T5_16Dec" = "predictions.T3...T5.Y.1",
  "T2_20Nov-T5_16Dec" = "predictions.T2...T5.Y.1",
  "T1_10Nov-T5_16Dec" = "predictions.T1...T5.Y.1"
)

# Long format: one row per (contrast, predicted effect).
tau.multi_sowing.forest_pred_long <- reshape2::melt(tau.multi_sowing.forest_pred)

# Ridge plot of the predicted effects per contrast, shaded by quartile.
# after_stat() replaces stat(), which is deprecated as of ggplot2 3.4.0.
ggplot(
  tau.multi_sowing.forest_pred_long,
  aes(x = value, y = variable, fill = factor(after_stat(quantile)))
) +
  stat_density_ridges(
    geom = "density_ridges_gradient", calc_ecdf = TRUE,
    quantiles = 4, quantile_lines = TRUE
  ) +
  scale_fill_viridis_d(name = "Quartiles") +
  theme_bw(base_size = 16) +
  labs(x = "Wheat yield gain(t/ha)", y = "Sowing date options")
Warning: `stat(quantile)` was deprecated in ggplot2 3.4.0.
i Please use `after_stat(quantile)` instead.
Warning: Using the `size` aesthetic with geom_segment was deprecated in ggplot2 3.4.0.
i Please use the `linewidth` aesthetic instead.

6 Transition matrix of the policy change

# Label for each policy-tree action id; position i is the label for action i
# (the trees report actions as 1: T5, 2: T4, 3: T3, 4: T2, 5: T1).
sowing_labels <- c("T5_16Dec", "T4_15Dec", "T3_30Nov", "T2_20Nov", "T1_10Nov")

# One vectorised lookup that fills the whole column; ids outside 1-5 become NA.
tr_assignment_sowing$depth2_cat <-
  sowing_labels[as.integer(tr_assignment_sowing$depth2)]

library(ggalluvial)
library(data.table)
library(dplyr)
library(scales)

tr_assignment_sowingDT <- data.table(tr_assignment_sowing)

# Number of observations for each (observed schedule, recommended schedule)
# pair. The count column is named directly rather than assigning to a variable
# called `sum` inside `j` and renaming the auto-generated `V1` afterwards.
TransitionMatrix_sowing <- tr_assignment_sowingDT[
  , .(Freq = .N),
  by = c("SowingSchedule", "depth2_cat")
]

# Alluvial diagram: observed schedule on the left axis, depth-2 recommendation
# on the right, flows coloured by the recommendation. Each stratum is labelled
# with its name and its share.
transitionmatrixplot_sowing <- ggplot(
  data = TransitionMatrix_sowing,
  aes(axis1 = SowingSchedule, axis2 = depth2_cat, y = Freq)
) +
  geom_alluvium(aes(fill = depth2_cat)) +
  geom_stratum() +
  geom_text(
    stat = "stratum",
    aes(label = paste(after_stat(stratum), percent(after_stat(prop))))
  ) +
  scale_x_discrete(
    limits = c("SowingSchedule", "depth2_cat"),
    expand = c(0.15, 0.05)
  ) +
  scale_fill_viridis_d() +
  theme_void(base_size = 20) +
  theme(legend.position = "none")

transitionmatrixplot_sowing
Warning: Using the `size` aesthetic in this geom was deprecated in ggplot2 3.4.0.
i Please use `linewidth` in the `default_aes` field and elsewhere instead.